{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41f42f46-af1e-40ec-a736-e0a6ff814399",
   "metadata": {},
   "source": [
    "# 引入依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8032c1ee-6786-4679-9660-31e3a080a39f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from transformers import AutoConfig, AutoModelForMaskedLM\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import DataCollatorForLanguageModeling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7d72759-934e-4157-b0eb-ebc72ec06cc6",
   "metadata": {},
   "source": [
    "# 模拟训练部分数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "44ef2308-315d-479c-a5dc-c95d1a7868e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>2021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>小米（MI）通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iph...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 text\n",
       "9   2021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V...\n",
       "11                           食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机\n",
       "5     WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天\n",
       "1     同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头\n",
       "4   峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小...\n",
       "8   华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可...\n",
       "17  源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套...\n",
       "7   QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀1...\n",
       "12  ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶2...\n",
       "19  抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音...\n",
       "2   小米（MI）通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(...\n",
       "6   常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌...\n",
       "15  机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳...\n",
       "14  【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽...\n",
       "20  胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保...\n",
       "10  韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器...\n",
       "0   苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6...\n",
       "16  OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN...\n",
       "18  ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数...\n",
       "13  鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强...\n",
       "3   官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iph..."
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_list = [['苹果13/12/11/promax钢化膜13pro/12pro防窥iphone/x/xr/6/7皇迎 【黑边钻石防爆-高清透明膜】2片. iPhone X'],\n",
    " ['同福茂 适用华为畅享20Pro摄像头玻璃镜片畅享Z前后置大小照相机头 【】畅享20Z后主摄像头'],\n",
    " ['小米（MI）通用mp3MP4播放器迷你qq音乐学生运动跑步随身听电子书录音笔外放苹果 蓝色(带外放+可插内存卡) 8GB  王者套餐'],\n",
    " ['官方旗舰华强北智能watch6太空人s6代apple手表se顶配男士iwatch运动型iphone手 44MM/蓝高配版【太空人/长续航/双按键/高清 套餐二“送回环表带“'],\n",
    " ['峰爵苹果13promax手机壳平安喜乐新款iPhone13虎年创意女款13pro全包本命年小羊皮防摔 苹果13-红色-日富一日老虎'],\n",
    " ['WiFi收钱提示音响二维码语音播报器收款付款到账蓝牙小音箱色 炫酷黑 【蓝牙版】超长待机45天'],\n",
    " ['常用安全警示全套标示牌安全标识牌车间施工生产警告标志牌提示贴标语严禁烟火禁止吸烟有电危险标牌定制 禁止攀爬 20x30cm'],\n",
    " ['QYHJD 适用荣耀9i背夹充电宝华为8/9/10青春版背夹电池nova青春版电源薄 荣耀10青春版宝石蓝两万毫安 送耳机+数据线'],\n",
    " ['华为同款佳mp3mp4音乐播放器有屏学生随身听苹果风P5外放OTG可爱迷你 红 【带外放】可插卡 64GB  HiFi耳机套餐【鎹彩膜贴纸】'],\n",
    " ['2021华为matepadpro保护套12.6带笔槽10.8寸matepad11平板电脑壳V7硅胶p 【炭灰色+钢化膜-笔槽三折款】 华为平板 C5(10.4英寸)'],\n",
    " ['韵果【小学初中高中生学习专用】蓝牙mp3英语听力随身听学生版小型便携式mp4mp5音乐播放器听读 2英寸 触屏【+品质耳机+蓝牙功能】 16G蓝牙版(蓝牙功能+帮下载资料+发声词典)'],\n",
    " ['食掌门 绞肉机800瓦大功率小型肉馅机电动灌肠机'],\n",
    " ['ABS A4机柜文件夹盒PS柜配电箱威图柜WJ-1PS资料袋AE箱WJ-2文件夹 橙色带胶200个起'],\n",
    " ['鹿纯红米k40手机壳k40pro新款液态硅胶k40pro+全包防摔潮男女款软壳k40游戏增强版直边简 红米K40游戏增强版【海军蓝】*+钢化膜'],\n",
    " ['【官方原装】荧阙 适用2021款iPad保护套10.2英英寸平板9代8电脑Air4/3带笔槽ipad 【笔槽书本款】 梦想成真 付费软件APP iPad mini6(8.3英寸)'],\n",
    " ['机旺马华为畅享20plus手机壳磨砂超薄硅胶全包防摔欧美潮牌个性创意卡通图案男士款女网红软壳保护套 巧克联名 畅享20plus【加网红挂绳】+全屏膜'],\n",
    " ['OPPOReno5手机壳男金属镜头全包reno5k磨砂OPPOreno5手机套5G誉科创SN84 黑红色【单壳】刀锋 oppo Reno5'],\n",
    " ['源由 游戏键盘鼠标吃鸡oppoA59/A77/A83/A73/79/R15梦境手机壳保护皮套 升级版(黑色)键盘皮套+鼠标'],\n",
    " ['ITDK充电器插头超级快充数据充电线 适用于 【Max40W超级快充】充电器+1.5米5A数据线 华为Mate20 Pro/Mate20X 5g4g'],\n",
    " ['抖音快手热门歌曲音乐播放器迷你运动随身听mp4学生MP3 小怪兽 外响版+带卡含200首抖音+四件套'],\n",
    " ['胜枫 华为p50pro手机壳车载磁吸款p50镜头全包防摔保护套典藏版磨砂防指纹硬壳纯色极简保护套 P50【岩砂黑+车载支架】磁吸套装']]\n",
    "\n",
    "        \n",
    "data_df = pd.DataFrame(data_list, columns=['text'])\n",
    "data_df = data_df.sample(frac=1)\n",
    "data_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee2629e9-5683-408a-bc47-4c910b7e7611",
   "metadata": {},
   "source": [
    "# 定义预训练模型&创建dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc804bdf-11c5-4d62-b4f8-b43871c59115",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "    定义BERT模型\n",
    "'''\n",
    "# model_checkpoint = 'chinese-roberta-wwm-ext'\n",
    "# tokenizer_checkpoint = 'chinese-roberta-wwm-ext'\n",
    "'''\n",
    "    本地引入模型\n",
    "'''\n",
    "model_checkpoint = 'D:/env/bert_model/hfl/chinese-roberta-wwm-ext'\n",
    "tokenizer_checkpoint = 'D:/env/bert_model/hfl/chinese-roberta-wwm-ext'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ccbef85-aae9-44d6-be5e-707db6162007",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 简单自定义dataset\n",
    "class PreTrainDataset(Dataset):\n",
    "    def __init__(self, data_list, tokenizer, max_seq_len):\n",
    "        super(PreTrainDataset, self).__init__()\n",
    "        self.data_list = data_list\n",
    "        self.len = len(data_list)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        example = self.data_list[index]\n",
    "        data = self.tokenizer.encode_plus(example, return_token_type_ids=True, return_attention_mask=True, padding = 'max_length', max_length=self.max_seq_len)\n",
    "        return {'input_ids':  torch.tensor(data['input_ids'][:self.max_seq_len], dtype=torch.long), \n",
    "                'token_type_ids':  torch.tensor(data['token_type_ids'][:self.max_seq_len], dtype=torch.long), \n",
    "                'attention_mask':  torch.tensor(data['attention_mask'][:self.max_seq_len], dtype=torch.long)}\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "# 生成dataset\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)\n",
    "test_size = int(len(data_df)*0.8)\n",
    "train_dataset = PreTrainDataset(data_df[:test_size]['text'].tolist(), tokenizer, 128)\n",
    "test_dataset = PreTrainDataset(data_df[test_size:]['text'].tolist(), tokenizer, 128)\n",
    "# 生成dataloader\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=8, num_workers=0)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=8, num_workers=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7335e8a-9103-43b1-a04a-fe3e2d45cd18",
   "metadata": {},
   "source": [
    "# 加载定义模型&定义trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f20be05f-f05e-44e7-8a78-aa0c6ee1377d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# 加载模型\n",
    "config = AutoConfig.from_pretrained(model_checkpoint)\n",
    "model = AutoModelForMaskedLM.from_config(config)\n",
    "\n",
    "model.to(device)\n",
    "# 定义trainer\n",
    "training_args = TrainingArguments(\n",
    "                          output_dir='pretrain_bert',\n",
    "                          overwrite_output_dir=True,\n",
    "                          do_train=True, \n",
    "                          do_eval=True,\n",
    "                          per_device_train_batch_size=32,\n",
    "                          per_device_eval_batch_size=100,\n",
    "                          evaluation_strategy='steps',\n",
    "                          logging_steps=10000,\n",
    "                          eval_steps = None,\n",
    "                          prediction_loss_only=True,\n",
    "                          learning_rate = 1e-5,\n",
    "                          weight_decay=0.01,\n",
    "                          adam_epsilon = 1e-8,\n",
    "                          max_grad_norm = 1.0,\n",
    "                          num_train_epochs = 10,\n",
    "                          save_steps = -1,\n",
    "                        push_to_hub=False\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    data_collator=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a856b4d-c988-45e4-bd6d-90452bb997fa",
   "metadata": {},
   "source": [
    "# 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea3ab395-06fa-4a35-bd16-b579f4cc5681",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()\n",
    "trainer.save_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
